# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them. This is the test generator only.  The
# test scripts themselves are in wmma-ptx*-sm*.py files.

# RUN: true

from __future__ import print_function

import argparse
from itertools import product
from string import Template

class MMAType:
    """A PTX element type together with its LLVM IR type and register regex.

    Attributes:
        ptx_type: the PTX type name the instance was created from.
        llvm_type: the LLVM IR type used for one register of this type.
        ptx_reg_pattern: regex text matching a PTX register of this type.

    Raises:
        KeyError: if ``ptx_type`` is not one of the known type names.
    """

    def __init__(self, ptx_type):
        self.ptx_type = ptx_type

        # Only the native FP types have dedicated LLVM types; every other
        # known type is carried in an i32.
        llvm_types = {"f16": "<2 x half>", "f32": "float", "f64": "double"}
        for packed in ("s32", "b16", "s8", "u8", "s4", "u4", "b1", "bf16", "tf32"):
            llvm_types[packed] = "i32"
        self.llvm_type = llvm_types[ptx_type]

        # f32 and f64 have their own register names; everything else
        # (including f16) is matched as a %r register.
        if ptx_type == "f32":
            self.ptx_reg_pattern = "%f[0-9]+"
        elif ptx_type == "f64":
            self.ptx_reg_pattern = "%fd[0-9]+"
        else:
            self.ptx_reg_pattern = "%r[0-9]+"

    def __repr__(self):
        return "/".join(["%s" % self.ptx_type, "%s" % self.llvm_type])


class MMAFrag:
    """One matrix fragment: geometry, fragment name, element type, register count.

    ``nregs`` is looked up first by the exact "geom:frag:ptx_type" key and,
    failing that, by the generic "frag:ptx_type" floating-point fallback.

    Attributes:
        geom: geometry string as passed in (e.g. "m16n16k16").
        frag: fragment name as passed in (e.g. "a", "b", "c", "d", "x1").
        mma_type: MMAType built from ``ptx_elt_type``.
        nregs: number of registers in the fragment.

    Raises:
        KeyError: from MMAType, if ``ptx_elt_type`` is unknown.
        AssertionError: if no register count is known for the combination.
    """

    # Register counts keyed by "geom:frag:ptx_type".  Accumulator entries that
    # several input types of the same geometry share are listed only once.
    _NREGS = {
        # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
        "m16n16k16:a:u8": 2,
        "m16n16k16:a:s8": 2,
        "m16n16k16:b:u8": 2,
        "m16n16k16:b:s8": 2,
        "m16n16k16:c:s32": 8,
        "m16n16k16:d:s32": 8,
        "m8n32k16:a:u8": 1,
        "m8n32k16:a:s8": 1,
        "m8n32k16:b:u8": 4,
        "m8n32k16:b:s8": 4,
        "m8n32k16:c:s32": 8,
        "m8n32k16:d:s32": 8,
        "m32n8k16:a:u8": 4,
        "m32n8k16:a:s8": 4,
        "m32n8k16:b:u8": 1,
        "m32n8k16:b:s8": 1,
        "m32n8k16:c:s32": 8,
        "m32n8k16:d:s32": 8,
        "m8n8k16:a:u8": 1,
        "m8n8k16:a:s8": 1,
        "m8n8k16:b:u8": 1,
        "m8n8k16:b:s8": 1,
        "m8n8k16:c:s32": 2,
        "m8n8k16:d:s32": 2,
        "m16n8k16:a:u8": 2,
        "m16n8k16:a:s8": 2,
        "m16n8k16:b:u8": 1,
        "m16n8k16:b:s8": 1,
        "m16n8k16:c:s32": 4,
        "m16n8k16:d:s32": 4,
        # m16n8k32 c/d:s32 are shared by the u8/s8 and u4/s4 variants.
        "m16n8k32:a:u8": 4,
        "m16n8k32:a:s8": 4,
        "m16n8k32:b:u8": 2,
        "m16n8k32:b:s8": 2,
        "m16n8k32:c:s32": 4,
        "m16n8k32:d:s32": 4,
        # u4/s4 -> s32 @ m8n8k32 (u4/s4)
        "m8n8k32:a:u4": 1,
        "m8n8k32:a:s4": 1,
        "m8n8k32:b:u4": 1,
        "m8n8k32:b:s4": 1,
        "m8n8k32:c:s32": 2,
        "m8n8k32:d:s32": 2,
        "m16n8k32:a:u4": 2,
        "m16n8k32:a:s4": 2,
        "m16n8k32:b:u4": 1,
        "m16n8k32:b:s4": 1,
        "m16n8k64:a:u4": 4,
        "m16n8k64:a:s4": 4,
        "m16n8k64:b:u4": 2,
        "m16n8k64:b:s4": 2,
        "m16n8k64:c:s32": 4,
        "m16n8k64:d:s32": 4,
        # b1 -> s32 @ m8n8k128(b1)
        "m8n8k128:a:b1": 1,
        "m8n8k128:b:b1": 1,
        "m8n8k128:c:s32": 2,
        "m8n8k128:d:s32": 2,
        "m16n8k128:a:b1": 2,
        "m16n8k128:b:b1": 1,
        "m16n8k128:c:s32": 4,
        "m16n8k128:d:s32": 4,
        "m16n8k256:a:b1": 4,
        "m16n8k256:b:b1": 2,
        "m16n8k256:c:s32": 4,
        "m16n8k256:d:s32": 4,
        # bf16 -> f32 @ m16n16k16/m8n32k16/m32n8k16
        "m16n16k16:a:bf16": 4,
        "m16n16k16:b:bf16": 4,
        "m8n32k16:a:bf16": 2,
        "m8n32k16:b:bf16": 8,
        "m32n8k16:a:bf16": 8,
        "m32n8k16:b:bf16": 2,
        # m16n8k16 c/d:f32 are shared by the bf16 and f16 variants.
        "m16n8k16:a:bf16": 4,
        "m16n8k16:b:bf16": 2,
        "m16n8k16:c:f32": 4,
        "m16n8k16:d:f32": 4,
        # m16n8k8 c/d:f32 are shared by the bf16, tf32 and f16 variants.
        "m16n8k8:a:bf16": 2,
        "m16n8k8:b:bf16": 1,
        "m16n8k8:c:f32": 4,
        "m16n8k8:d:f32": 4,
        # f64 @ m8n8k4
        "m8n8k4:a:f64": 1,
        "m8n8k4:b:f64": 1,
        "m8n8k4:c:f64": 2,
        "m8n8k4:d:f64": 2,
        # tf32 -> f32 @ m16n16k8
        "m16n16k8:a:tf32": 4,
        "m16n16k8:b:tf32": 4,
        "m16n8k4:a:tf32": 2,
        "m16n8k4:b:tf32": 1,
        "m16n8k4:c:f32": 4,
        "m16n8k4:d:f32": 4,
        "m16n8k8:a:tf32": 4,
        "m16n8k8:b:tf32": 2,
        # f16 @ m8n8k4/m16n8k8/m16n8k16
        "m8n8k4:a:f16": 2,
        "m8n8k4:b:f16": 2,
        "m16n8k8:a:f16": 2,
        "m16n8k8:b:f16": 1,
        "m16n8k8:c:f16": 2,
        "m16n8k8:d:f16": 2,
        "m16n8k16:a:f16": 4,
        "m16n8k16:b:f16": 2,
        "m16n8k16:c:f16": 2,
        "m16n8k16:d:f16": 2,
        # ldmatrix
        "m8n8:x1:b16": 1,
        "m8n8:x2:b16": 2,
        "m8n8:x4:b16": 4,
    }

    # All other FP shape/fragment/type combinations have the same size,
    # keyed by "frag:ptx_type".
    _DEFAULT_FP_NREGS = {
        "a:f16": 8,
        "b:f16": 8,
        "c:f16": 4,
        "d:f16": 4,
        "c:f32": 8,
        "d:f32": 8,
    }

    def __init__(self, geom, frag, ptx_elt_type):
        self.geom = geom
        self.frag = frag
        self.mma_type = MMAType(ptx_elt_type)

        exact_key = "%s:%s:%s" % (geom, frag, ptx_elt_type)
        generic_key = "%s:%s" % (frag, ptx_elt_type)
        self.nregs = self._NREGS.get(
            exact_key, self._DEFAULT_FP_NREGS.get(generic_key, None)
        )
        # Raised explicitly (rather than via `assert`) so the check is still
        # performed when Python runs with -O.
        if not self.nregs:
            raise AssertionError(
                "No register count known for fragment %s" % exact_key
            )

    def __repr__(self):
        return "%s:%s:%s%s" % (
            self.geom,
            self.frag,
            self.mma_type,
            "" if self.nregs == 1 else ("*%d" % self.nregs),
        )


class MMAOp:
    """Bundle of the four fragments (A, B, C, D) taking part in one MMA op."""

    def __init__(self, a, b, c, d):
        self.a, self.b, self.c, self.d = a, b, c, d

    def __repr__(self):
        labelled = [("A", self.a), ("B", self.b), ("C", self.c), ("D", self.d)]
        return "{" + ", ".join("%s:%s" % pair for pair in labelled) + "}"


def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
    """Build an MMAOp for every geometry/type combination.

    An empty ``types_b`` means "B has the same type as A"; an empty
    ``types_d`` means "D has the same type as C".

    Returns:
        List of MMAOp, ordered by (geom, type_a, type_c) and then by
        (type_b, type_d).
    """
    return [
        MMAOp(
            MMAFrag(geom, "a", type_a),
            MMAFrag(geom, "b", type_b),
            MMAFrag(geom, "c", type_c),
            MMAFrag(geom, "d", type_d),
        )
        for geom, type_a, type_c in product(geoms, types_a, types_c)
        for type_b, type_d in product(types_b or [type_a], types_d or [type_c])
    ]


def make_ldst_ops(geoms, frags, types):
    """Return an MMAFrag for each (geom, frag, type) combination, in product order."""
    fragments = []
    for combo in product(geoms, frags, types):
        fragments.append(MMAFrag(*combo))
    return fragments


def make_ldmatrix_ops(geoms, frags, types):
    """Return an MMAFrag for each (geom, frag, type) ldmatrix combination."""
    combos = product(geoms, frags, types)
    return [MMAFrag(*combo) for combo in combos]


def get_wmma_ops():
    """Return the MMAOps that the WMMA tests are generated from."""
    # Each entry is (geoms, types_a, types_b, types_c, types_d) as taken by
    # make_mma_ops.
    specs = [
        (["m16n16k8"], ["tf32"], [], ["f32"], []),
        (["m16n16k16", "m32n8k16", "m8n32k16"], ["bf16"], [], ["f32"], []),
        (["m8n8k4"], ["f64"], [], ["f64"], []),
        (
            ["m16n16k16", "m32n8k16", "m8n32k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        ),
        (["m16n16k16", "m32n8k16", "m8n32k16"], ["s8", "u8"], [], ["s32"], []),
        (["m8n8k32"], ["s4", "u4"], [], ["s32"], []),
        (["m8n8k128"], ["b1"], [], ["s32"], []),
    ]
    ops = []
    for spec in specs:
        ops.extend(make_mma_ops(*spec))
    return ops


def get_mma_ops():
    """Return the MMAOps that the MMA tests are generated from."""
    # Each entry is (geoms, types_a, types_b, types_c, types_d) as taken by
    # make_mma_ops.
    specs = [
        (["m8n8k4"], ["f64"], [], ["f64"], []),
        (["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], []),
        (["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], []),
        (
            ["m8n8k4", "m16n8k8", "m16n8k16"],
            ["f16"],
            [],
            ["f16", "f32"],
            ["f16", "f32"],
        ),
        (
            ["m8n8k16", "m16n8k16", "m16n8k32"],
            ["s8", "u8"],
            ["s8", "u8"],
            ["s32"],
            [],
        ),
        (
            ["m8n8k32", "m16n8k32", "m16n8k64"],
            ["s4", "u4"],
            ["s4", "u4"],
            ["s32"],
            [],
        ),
        (["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], []),
    ]
    ops = []
    for spec in specs:
        ops.extend(make_mma_ops(*spec))
    return ops


def get_ldst_ops(kind):
    """Return the WMMA load/store fragments for ``kind``.

    When ``kind`` is "store" only "d" fragments are returned; for any other
    value every fragment except "d" is returned.
    """
    wmma_shapes = ["m16n16k16", "m32n8k16", "m8n32k16"]
    # Each entry is (geoms, frags, types) as taken by make_ldst_ops.
    specs = [
        (wmma_shapes, ["a", "b"], ["f16", "u8", "s8", "bf16"]),
        (wmma_shapes, ["c", "d"], ["f16", "f32", "s32"]),
        (["m8n8k32"], ["a", "b"], ["s4", "u4"]),
        (["m8n8k128"], ["a", "b"], ["b1"]),
        (["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]),
        (["m8n8k4"], ["a", "b", "c", "d"], ["f64"]),
        (["m16n16k8"], ["a", "b"], ["tf32"]),
        (["m16n16k8"], ["c", "d"], ["f32"]),
    ]
    want_store = kind == "store"
    selected = []
    for geoms, frags, types in specs:
        for fragment in make_ldst_ops(geoms, frags, types):
            is_result = fragment.frag == "d"
            if is_result == want_store:
                selected.append(fragment)
    return selected


def get_ldmatrix_ops():
    """Return the fragments that the ldmatrix tests are generated from."""
    geoms = ["m8n8"]
    frags = ["x1", "x2", "x4"]
    types = ["b16"]
    return make_ldmatrix_ops(geoms, frags, types)


def is_wmma_geom_supported(geom):
    """Return whether ``geom`` is enabled for the selected PTX version / GPU arch.

    Reads the module-level ``ptx_version`` and ``gpu_arch`` set by ``main``.

    Raises:
        AssertionError: if ``geom`` is not a geometry this function knows.
    """
    # geometries for FP and ints.
    if geom in ["m8n32k16", "m32n8k16"]:
        return ptx_version >= 61
    # geometries for sub-ints.
    if geom in ["m8n8k32", "m8n8k128"]:
        return ptx_version >= 63 and gpu_arch >= 75
    if geom == "m16n16k16":
        return ptx_version >= 60
    if geom == "m16n8k8":
        return ptx_version >= 65
    if geom in ["m16n16k8", "m8n8k4"]:
        return ptx_version >= 70
    # Raised explicitly: a bare `assert False` is stripped under `python -O`
    # and the function would then silently return None.
    raise AssertionError("Unexpected WMMA geometry: %r" % (geom,))


def is_mma_geom_supported(geom):
    """Return whether MMA geometry ``geom`` is enabled for the selected PTX version.

    Reads the module-level ``ptx_version`` set by ``main``.

    Raises:
        AssertionError: if ``geom`` is not a geometry this function knows.
    """
    # geometries for FP and ints.
    if geom == "m8n8k4":
        return ptx_version >= 64
    if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
        return ptx_version >= 65
    if geom in [
        "m16n8k16",
        "m16n8k4",
        "m16n8k32",
        "m16n8k64",
        "m8n8k128",
        "m16n8k128",
        "m16n8k256",
    ]:
        return ptx_version >= 70
    # Raised explicitly: a bare `assert False` is stripped under `python -O`
    # and the function would then silently return None.
    raise AssertionError("Unexpected MMA geometry: %r" % (geom,))


def is_ldmatrix_geom_supported(geom):
    """Return whether ldmatrix geometry ``geom`` is enabled for the selected target.

    Reads the module-level ``ptx_version`` and ``gpu_arch`` set by ``main``.

    Raises:
        AssertionError: if ``geom`` is not a geometry this function knows.
    """
    if geom in ["m8n8"]:
        return ptx_version >= 65 and gpu_arch >= 75
    # Raised explicitly: a bare `assert False` is stripped under `python -O`
    # and the function would then silently return None.
    raise AssertionError("Unexpected ldmatrix geometry: %r" % (geom,))


def is_type_supported(ptx_type):
    """Return whether ``ptx_type`` is enabled for the selected PTX version / GPU arch.

    Reads the module-level ``ptx_version`` and ``gpu_arch`` set by ``main``.
    Types without a dedicated rule fall back to PTX >= 60 and sm >= 70.
    """
    # (types, minimum PTX version, minimum GPU arch or None for "not checked").
    requirements = [
        (("s8", "u8", "s32"), 63, 72),
        (("s4", "u4", "b1"), 63, 75),
        (("b16",), 65, 75),
        (("bf16", "tf32", "f64"), 70, None),
    ]
    min_ptx, min_arch = 60, 70
    for type_names, need_ptx, need_arch in requirements:
        if ptx_type in type_names:
            min_ptx, min_arch = need_ptx, need_arch
            break
    if min_arch is None:
        return ptx_version >= min_ptx
    return ptx_version >= min_ptx and gpu_arch >= min_arch


def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
    """Return whether this WMMA op/layout/rounding/saturation combination is generated.

    The decision is driven by the type and geometry of the A fragment.
    """
    a_type = op.a.mma_type.ptx_type

    if not is_type_supported(a_type):
        return False
    if not is_wmma_geom_supported(op.a.geom):
        return False

    # rnd is only supported for FP64 WMMA
    if rnd and a_type != "f64":
        return False

    if satf:
        if a_type == "f16":
            # satfinite for floating points was removed in PTX 6.5
            if ptx_version >= 65:
                return False
        elif a_type not in ("s8", "u8", "s4", "u4"):
            return False

    # sub-integer require row/col layout.
    if a_type in ("s4", "u4", "b1"):
        return layout_a == "row" and layout_b == "col"
    return True


def is_mma_variant_supported(op, layout_a, layout_b, satf):
    """Return whether this MMA op/layout/saturation combination is generated."""
    geom = op.a.geom
    a_type = op.a.mma_type.ptx_type

    if not is_type_supported(a_type):
        return False
    if not is_mma_geom_supported(geom):
        return False

    # Saturation is only generated for the integer input types.
    if satf and a_type not in ("s8", "u8", "s4", "u4"):
        return False

    c_type = op.c.mma_type.ptx_type
    d_type = op.d.mma_type.ptx_type

    # If the type of C is f32 then so must the type of D
    if geom == "m8n8k4" and c_type == "f32" and d_type != "f32":
        return False

    # A and B type must be the same. C and D type must be the same
    if geom == "m16n8k8":
        if a_type != op.b.mma_type.ptx_type or c_type != d_type:
            return False

    # C and D type must be the same
    if geom == "m16n8k16" and c_type != d_type:
        return False

    # Require row/col layout for all MMA except m8n8k4 on FP16
    if geom == "m8n8k4" and a_type == "f16":
        return True
    return layout_a == "row" and layout_b == "col"


def is_ldst_variant_supported(frag, layout):
    """Return whether a WMMA load/store of ``frag`` with ``layout`` is generated."""
    elt_type = frag.mma_type.ptx_type
    if not is_type_supported(elt_type):
        return False
    if not is_wmma_geom_supported(frag.geom):
        return False

    if elt_type not in ("s4", "u4", "b1"):
        return True

    # sub-integer require sm_75 and ptx63, row/col layout for a/b.
    if frag.frag == "a":
        return layout == "row"
    if frag.frag == "b":
        return layout == "col"
    return frag.frag in ("c", "d")


def is_ldmatrix_variant_supported(frag):
    """Return whether an ldmatrix test for ``frag`` is generated."""
    if not is_type_supported(frag.mma_type.ptx_type):
        return False
    if not is_ldmatrix_geom_supported(frag.geom):
        return False
    return frag.frag in ("x1", "x2", "x4")


def make_wmma_slice_ty(frag):
    """Return the fragment's LLVM type repeated once per register."""
    register_ty = frag.mma_type.llvm_type
    return [register_ty for _ in range(frag.nregs)]


def make_wmma_ld_ret_ty(frag):
    """Return the LLVM return type text for a fragment.

    A single-register fragment yields the bare type; otherwise the types are
    wrapped in braces as a comma-separated list.
    """
    register_tys = make_wmma_slice_ty(frag)
    if len(register_tys) != 1:
        return "{" + ", ".join(register_tys) + "}"
    return str(register_tys[0])


def get_aspace(space):
    """Return the address-space number for a state-space suffix.

    Both the empty string and ".generic" map to 0.

    Raises:
        KeyError: if ``space`` is not one of the known suffixes.
    """
    numbered_spaces = dict(
        [
            (".global", 1),
            (".shared", 3),
            (".const", 4),
            (".local", 5),
            (".param", 101),
            ("", 0),
            (".generic", 0),
        ]
    )
    return numbered_spaces[space]


def get_pspace(space):
    """Return the "p<N>i8" pointer-type suffix used in intrinsic names."""
    aspace = get_aspace(space)
    return "p" + ("%d" % aspace) + "i8"


def check_pattern(frag):
    """Return a ``{{...}}`` regex block matching the fragment's register list.

    The fragment's register pattern is repeated ``nregs`` times, separated by
    a comma and optional spaces.
    """
    one_register = frag.mma_type.ptx_reg_pattern
    register_list = ", *".join(one_register for _ in range(frag.nregs))
    return "{{" + register_list + "}}"


def gen_wmma_load_tests():
    """Print the WMMA load tests and return what was generated.

    For every supported (fragment, layout, address space, stride) combination
    two test functions are printed: one loading from the pointer and one
    (suffix ``_o``) loading from the pointer plus 128.

    Returns:
        List of (intrinsic name, instruction text) tuples, one per combination.
    """
    load_template = Template("""
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
  ret ${ret_ty} %v0;
}
""")
    intrinsic_template = Template(
        "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
    )
    instruction_template = Template(
        "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
    )

    emitted = []
    variants = product(
        get_ldst_ops("load"),
        ["row", "col"],
        ["", ".shared", ".global"],
        ["", ".stride"],
    )
    for frag, layout, space, stride in variants:
        if not is_ldst_variant_supported(frag, layout):
            continue

        subst = {
            "abc": frag.frag,
            "aligned": ".aligned" if ptx_version >= 63 else "",
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }
        intrinsic = intrinsic_template.substitute(subst)
        instruction = instruction_template.substitute(subst)

        # The strided variants take an extra i32 operand that shows up as one
        # more register after the address.
        subst.update(
            intrinsic=intrinsic,
            function=intrinsic.replace(".", "_"),
            instruction=instruction,
            ret_ty=make_wmma_ld_ret_ty(frag),
            check_result=check_pattern(frag),
            extra_args=", i32 %stride" if stride else "",
            stride_pattern=", %r{{[0-9]+}}" if stride else "",
        )

        print(load_template.substitute(subst))
        emitted.append((intrinsic, instruction))

    return emitted


def make_wmma_slice_args(frag):
    """Return the fragment's registers as "<type> %<frag><index>" arguments.

    The arguments are joined with ", " and numbered from 0.
    """
    arguments = []
    for index, register_ty in enumerate(make_wmma_slice_ty(frag)):
        arguments.append("%s %%%s%d" % (register_ty, frag.frag, index))
    return ", ".join(arguments)


def gen_wmma_store_tests():
    """Print the WMMA store tests and return what was generated.

    For every supported (fragment, layout, address space, stride) combination
    two test functions are printed: one storing to the pointer and one
    (suffix ``_o``) storing to the pointer plus 128.

    Returns:
        List of (intrinsic name, instruction text) tuples, one per combination.
    """
    # NOTE: the address register is matched with {{[0-9]+}} (one or more
    # digits).  The previous {{[0-9+]}} was a character class matching a
    # single digit or '+', which only worked for one-digit register numbers.
    store_template = """
declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}
; CHECK: {${check_args}}
; CHECK: ${stride_pattern}
  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
  ret void
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}+128]
; CHECK: ${check_args}
; CHECK: ${stride_pattern}
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
  ret void
}
"""
    intrinsic_template = (
        "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
    )
    instruction_template = (
        "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
    )

    generated_items = []

    for frag, layout, space, stride in product(
        get_ldst_ops("store"),
        ["row", "col"],
        ["", ".shared", ".global"],
        ["", ".stride"],
    ):
        if not is_ldst_variant_supported(frag, layout):
            continue

        params = {
            "abc": frag.frag,
            "aligned": ".aligned" if ptx_version >= 63 else "",
            "layout": layout,
            "space": space,
            "stride": stride,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }

        test_params = params
        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
        test_params["function"] = test_params["intrinsic"].replace(".", "_")
        test_params["instruction"] = Template(instruction_template).substitute(params)
        # Stores return void, so no "ret_ty" substitution is needed here.
        test_params["check_args"] = check_pattern(frag)
        if stride:
            # Strided stores take an extra i32 operand that is emitted as one
            # more register before the terminating ';'.
            test_params["extra_args"] = ", i32 %stride"
            test_params["stride_pattern"] = ", %r{{[0-9]+}};"
        else:
            test_params["extra_args"] = ""
            test_params["stride_pattern"] = ";"
        test_params["args"] = make_wmma_slice_args(frag)

        print(Template(store_template).substitute(test_params))
        generated_items.append((test_params["intrinsic"], test_params["instruction"]))

    return generated_items


def gen_ldmatrix_tests():
    """Print the ldmatrix tests and return what was generated.

    For every supported (fragment, address space, transpose) combination two
    test functions are printed: one loading from the pointer and one
    (suffix ``_o``) loading from the pointer plus 128.

    Returns:
        List of (intrinsic name, instruction text) tuples, one per combination.
    """
    ldmatrix_template = Template("""
declare ${ret_ty} @${intrinsic}(i8 ${as}* %src);

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}]
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src);
  ret ${ret_ty} %v0;
}

; CHECK-LABEL: .func{{.*}}test_${function}_o(
define ${ret_ty} @test_${function}_o(i8 ${as}* %src) {
; CHECK: ${instruction}
; CHECK: {${check_result}}
; CHECK: [%rd{{[0-9]+}}+128]
  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1);
  ret ${ret_ty} %v0;
}
""")
    intrinsic_template = Template(
        "llvm.nvvm.ldmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
    )
    instruction_template = Template(
        "ldmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
    )

    emitted = []
    variants = product(
        get_ldmatrix_ops(),
        ["", ".shared"],
        ["", ".trans"],
    )
    for frag, space, trans in variants:
        if not is_ldmatrix_variant_supported(frag):
            continue

        subst = {
            "frag": frag.frag,
            "space": space,
            "trans": trans,
            "itype": frag.mma_type.ptx_type,
            "pspace": get_pspace(space),
            "as": "addrspace(%d)" % get_aspace(space),
            "geom": frag.geom,
        }
        intrinsic = intrinsic_template.substitute(subst)
        instruction = instruction_template.substitute(subst)

        subst.update(
            intrinsic=intrinsic,
            function=intrinsic.replace(".", "_"),
            instruction=instruction,
            ret_ty=make_wmma_ld_ret_ty(frag),
            check_result=check_pattern(frag),
        )

        print(ldmatrix_template.substitute(subst))
        emitted.append((intrinsic, instruction))

    return emitted


def mma_signature(op):
    """Return the type suffix used in the MMA intrinsic name for ``op``."""
    a_type = op.a.mma_type.ptx_type
    if a_type == "f16":
        # FP16 ops identified by accumulator & result type.
        return ".".join([op.d.mma_type.ptx_type, op.c.mma_type.ptx_type])

    b_type = op.b.mma_type.ptx_type
    if a_type == b_type:
        # if input types are the same, it only appears once.
        return a_type
    # other ops are identified by input types.
    return ".".join([a_type, b_type])


def mma_ptx_signature(op):
    """Return the PTX type suffix for an MMA instruction."""
    # Encode all four types as D.A.B.C
    ordered = [op.d, op.a, op.b, op.c]
    return ".".join([fragment.mma_type.ptx_type for fragment in ordered])


def wmma_signature(op):
    """Return the type suffix used in the WMMA intrinsic name for ``op``."""
    a_type = op.a.mma_type.ptx_type
    if a_type != "f16":
        # other ops are identified by input type.
        return a_type
    # FP16 ops identified by accumulator & result type.
    return ".".join([op.d.mma_type.ptx_type, op.c.mma_type.ptx_type])


def wmma_ptx_signature(op):
    """Return the PTX type suffix for a WMMA instruction."""
    if op.a.mma_type.ptx_type == "f16":
        # FP16 instructions use D.C
        ordered = [op.d, op.c]
    else:
        # other instructions encode all four types as D.A.B.C
        ordered = [op.d, op.a, op.b, op.c]
    return ".".join([fragment.mma_type.ptx_type for fragment in ordered])


def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
    """Print one MMA/WMMA test function and return what it covers.

    ``params`` is filled in place with the derived substitutions
    (intrinsic, function, instruction, ret_ty, check_a..check_d, args).

    Returns:
        Tuple of (intrinsic name, instruction text).
    """
    mma_template = """
declare ${ret_ty} @${intrinsic}(
        ${args});

; CHECK-LABEL: .func {{.*}}test_${function}(
define ${ret_ty} @test_${function}(
        ${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
; CHECK-NEXT: ${check_a}
; CHECK-NEXT: ${check_b}
; CHECK-NEXT: ${check_c}
  %r = call ${ret_ty} @${intrinsic}(
        ${args});
  ret ${ret_ty} %r;
}
"""

    intrinsic = Template(intrinsic_template).substitute(params)
    params["intrinsic"] = intrinsic
    params["function"] = intrinsic.replace(".", "_")
    instruction = Template(instruction_template).substitute(params)
    params["instruction"] = instruction
    params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
    for name, fragment in (("a", op.a), ("b", op.b), ("c", op.c), ("d", op.d)):
        params["check_" + name] = check_pattern(fragment)

    # Only the inputs are arguments; D is the return value.
    input_args = [make_wmma_slice_args(fragment) for fragment in (op.a, op.b, op.c)]
    params["args"] = ",\n        ".join(input_args)

    print(Template(mma_template).substitute(params))
    return (intrinsic, instruction)


def get_b1_ops(ptx_type):
    """Return the bit-op suffixes to generate for ``ptx_type``.

    Non-"b1" types get a single empty suffix.  For "b1", ".xor.popc" is always
    listed and ".and.popc" is added when ``ptx_version`` is at least 71.
    """
    if ptx_type != "b1":
        return [""]
    suffixes = [".xor.popc"]
    if ptx_version >= 71:
        suffixes.append(".and.popc")
    return suffixes


def gen_wmma_mma_tests():
    """Print the WMMA mma tests and return what was generated.

    Returns:
        List of (intrinsic name, instruction text) tuples, one per test.
    """
    intrinsic_template = "llvm.nvvm.wmma.${geom}.mma${b1op}.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
    instruction_template = "wmma.mma${b1op}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"

    layouts = ["row", "col"]
    rounding_modes = [".rn", ".rz", ".rm", ".rp", ""]
    saturation = [".satfinite", ""]

    emitted = []
    for op, alayout, blayout, rnd, satf in product(
        get_wmma_ops(), layouts, layouts, rounding_modes, saturation
    ):
        if is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
            for b1op in get_b1_ops(op.a.mma_type.ptx_type):
                params = {
                    "aligned": ".aligned" if ptx_version >= 63 else "",
                    "alayout": alayout,
                    "blayout": blayout,
                    "intrinsic_signature": wmma_signature(op),
                    "ptx_signature": wmma_ptx_signature(op),
                    "satf": satf,
                    "rnd": rnd,
                    "geom": op.a.geom,
                    "b1op": b1op,
                }
                item = common_mma_test_gen(
                    params, op, intrinsic_template, instruction_template
                )
                emitted.append(item)

    return emitted


def gen_mma_tests():
    """Print the MMA tests and return what was generated.

    Returns:
        List of (intrinsic name, instruction text) tuples, one per test.
    """
    intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
    instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"

    layouts = ["row", "col"]
    saturation = [".satfinite", ""]

    emitted = []
    for op, alayout, blayout, satf in product(
        get_mma_ops(), layouts, layouts, saturation
    ):
        if is_mma_variant_supported(op, alayout, blayout, satf):
            for b1op in get_b1_ops(op.a.mma_type.ptx_type):
                params = {
                    "aligned": ".aligned" if ptx_version >= 63 else "",
                    "alayout": alayout,
                    "blayout": blayout,
                    "intrinsic_signature": mma_signature(op),
                    "ptx_signature": mma_ptx_signature(op),
                    "satf": satf,
                    "geom": op.a.geom,
                    "b1op": b1op,
                }
                item = common_mma_test_gen(
                    params, op, intrinsic_template, instruction_template
                )
                emitted.append(item)

    return emitted


# Append complete list of intrinsics and instructions we've generated tests for.
# Generate set of checks to verify that we did generate sensible set of
# tests for the given combination of PTX and SM variants.
#
def gen_check_unsupported_ops(items):
    """Print the coverage checks followed by the sorted list of generated items.

    Args:
        items: iterable of (intrinsic name, instruction text) tuples.

    Reads the module-level ``ptx_version`` and ``gpu_arch`` for the header line.
    """
    print(
        "; Complete list of intrinsics supported by PTX%d on sm_%d"
        % (ptx_version, gpu_arch)
    )
    print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
    # Raw string: the check patterns below contain regex backslash escapes
    # that are not valid Python string escapes.
    print(
        r"""

; NOEXTGEOM-NOT: {{m8n32|m32n8}}
; NOINT-NOT: .{{s32|s8}}
; NOSUBINT-NOT: {{s4|u4|b1}}
; NOMMA-NOT: .m8n8k4.
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64
; NOLDMATRIX-NOT: ldmatrix.sync.aligned

; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f32
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f16.f16
; M16N16-DAG: m16n16k16.mma.{{.*}}.f32.f32

; PTX60 adds support for m32n8k16/m8n32k16 geometries.
; EXTGEOM-DAG: m32n8k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m32n8k16.mma.{{.*}}.f32.f32

; EXTGEOM-DAG: m8n32k16.load.{{[ab].*}}.f16.p
; EXTGEOM-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f32
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f16.f16
; EXTGEOM-DAG: m8n32k16.mma.{{.*}}.f32.f32

; INT-DAG: m16n16k16.load.{{[ab].*}}.s8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.s8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.s8.p
; INT-DAG: m16n16k16.load.{{[ab].*}}.u8.p
; INT-DAG: m8n32k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.load.{{[ab].*}}.u8.p
; INT-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
; INT-DAG: m16n16k16.mma.{{.*}}.u8
; INT-DAG: m16n16k16.mma.{{.*}}.s8
; INT-DAG: m8n32k16.mma.{{.*}}.u8
; INT-DAG: m8n32k16.mma.{{.*}}.s8
; INT-DAG: m32n8k16.mma.{{.*}}.u8
; INT-DAG: m32n8k16.mma.{{.*}}.s8

; SUBINT-DAG: m8n8k128.load.{{[ab].*}}.b1.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.s4.p
; SUBINT-DAG: m8n8k32.load.{{[ab].*}}.u4.p
; SUBINT-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
; SUBINT-DAG: m8n8k32.mma.{{.*}}.u4
; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1

; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32

; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64

; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32

; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4

; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16

; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k16.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k8.row.col.bf16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f16.f16
; PTX71MMA-DAG: mma.m16n8k16.row.col.f32.f32
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
; PTX71MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
; PTX71MMA-DAG: mma.and.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m8n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k128.row.col.b1
; PTX71MMA-DAG: mma.and.popc.m16n8k256.row.col.b1
; PTX71MMA-DAG: mma.xor.popc.m16n8k256.row.col.b1
;

"""
    )

    # The generated items are listed in sorted order between the
    # BEGIN/END markers.
    print("; INTRINSICS_LIST_BEGIN")
    for intrinsic, instruction in sorted(items):
        print("; ", intrinsic, " -> ", instruction, "")
    print("; INTRINSICS_LIST_END")
    print("; INTRINSICS: ; INTRINSICS_LIST_END")


def gen_tests():
    """Run every test generator in order, then print the coverage checks."""
    generators = (
        gen_wmma_load_tests,
        gen_wmma_store_tests,
        gen_ldmatrix_tests,
        gen_wmma_mma_tests,
        gen_mma_tests,
    )
    items = []
    for generate in generators:
        items.extend(generate())
    gen_check_unsupported_ops(items)


def main():
    """Parse --ptx / --gpu-arch, publish them as module globals, and generate the tests."""
    global ptx_version, gpu_arch

    parser = argparse.ArgumentParser()
    parser.add_argument("--ptx", type=int, default=60)
    parser.add_argument("--gpu-arch", type=int, default=70)
    options = parser.parse_args()

    # The support predicates and generators read these two module globals.
    ptx_version, gpu_arch = options.ptx, options.gpu_arch

    gen_tests()


# Generate the tests only when run as a script, not when imported.
if __name__ == "__main__":
    main()
